# -*- coding: utf-8 -*-

import decimal
from datetime import datetime

from sqlalchemy.engine import create_engine
from sqlalchemy.schema import MetaData

from bopress.cache import Cache
from bopress.log import Logger
from bopress.utils import Utils

__author__ = 'yezang'


class ShortColumn(object):
    """Flattened, template-friendly description of a single table column.

    Instances start from the class-level defaults below; callers overwrite
    the fields they know about.
    """

    # Column name.
    name = ""
    # Column flags: NULL allowed, part of the primary key, has a foreign key.
    nullable = primary_key = foreign_key = False
    # Coarse type label chosen by the caller (e.g. "number", "text").
    data_type = ""
    # Size information: primary length/precision and secondary length/scale.
    length = ext_length = 0


class DataBase(object):
    """Reflect a database schema and serve it to the code-generation page.

    All state lives in the class attribute ``__tables_meta__``, so every
    method works against whichever database :meth:`connect` most recently
    reflected (shared process-wide, not per request).
    """

    # Reflected sqlalchemy MetaData of the current connection, or None when no
    # connection has been made or the last attempt failed.
    __tables_meta__ = None

    # db_type accepted by build_dsn -> SQLAlchemy "dialect+driver" URL scheme.
    # sqlite is handled separately because its URL carries no credentials/host.
    _DSN_SCHEMES = {
        "pymysql": "mysql+pymysql",
        "pymssql": "mssql+pymssql",
        "psycopg2": "postgresql+psycopg2",
        "cx_oracle": "oracle+cx_oracle",
    }

    @staticmethod
    def build_dsn(db_type="sqlite", host="127.0.0.1", port=3306, user="root", pwd="root", db_name="test"):
        """Build an SQLAlchemy connection URL.

        :param db_type: "sqlite" or one of the driver names in ``_DSN_SCHEMES``.
        :param host: server host (ignored for sqlite).
        :param port: server port; left out of the URL when "" or None.
        :param user: login name (percent-encoded).
        :param pwd: password (percent-encoded).
        :param db_name: database name, or the database file path for sqlite.
        :return: the URL string, or "" for an unsupported ``db_type``.
        """
        try:
            from urllib.parse import quote
        except ImportError:  # Python 2
            from urllib import quote

        if db_type == "sqlite":
            return "sqlite:///%s" % db_name
        scheme = DataBase._DSN_SCHEMES.get(db_type)
        if scheme is None:
            return ""
        # Percent-encode credentials: an unescaped '@', ':' or '/' in a
        # password would otherwise be parsed as a URL delimiter. safe="" so
        # '/' is encoded too; quote() encodes a space as %20 rather than '+'.
        credentials = "%s:%s" % (quote("%s" % user, safe=""), quote("%s" % pwd, safe=""))
        # An empty port (e.g. a blank form field) would yield "host:/db".
        address = host if port in ("", None) else "%s:%s" % (host, port)
        return "%s://%s@%s/%s" % (scheme, credentials, address, db_name)

    @staticmethod
    def connect(dsn):
        """Connect to ``dsn`` and reflect its tables and views.

        :param dsn: SQLAlchemy connection URL; a falsy value clears the schema.
        :return: the reflected MetaData on success, None when ``dsn`` is empty,
            or False when connecting/reflecting fails (the error is logged and
            the current schema is cleared).
        """
        if not dsn:
            DataBase.__tables_meta__ = None
            return DataBase.__tables_meta__
        engine = None
        try:
            engine = create_engine(dsn)
            meta = MetaData()
            # Hand the engine to reflect() instead of binding it to MetaData:
            # bound metadata (MetaData(engine)) was removed in SQLAlchemy 2.0,
            # while reflect(bind=...) works on old and new versions alike.
            meta.reflect(bind=engine, views=True)
        except Exception as e:
            Logger.exception(e)
            # Never leave a half-initialised or stale schema for the other
            # methods to read.
            DataBase.__tables_meta__ = None
            return False
        finally:
            # The metadata keeps no reference to the engine, so release its
            # pooled connections as soon as reflection is done.
            if engine is not None:
                engine.dispose()
        DataBase.__tables_meta__ = meta
        return meta

    @staticmethod
    def _table(table_name):
        """Return the reflected Table ``table_name``, or None when there is no
        current schema, the name is empty, or no such table/view was reflected."""
        meta = DataBase.__tables_meta__
        if not table_name or not meta:
            return None
        return meta.tables.get(table_name)

    @staticmethod
    def _python_type(column):
        """Return the Python type SQLAlchemy maps ``column`` to, or None when its
        type does not provide one (``python_type`` raises NotImplementedError,
        e.g. for NullType)."""
        try:
            return column.type.python_type
        except NotImplementedError:
            return None

    @staticmethod
    def tables():
        """Return the names of all reflected tables and views ([] if not connected)."""
        meta = DataBase.__tables_meta__
        return list(meta.tables.keys()) if meta else []

    @staticmethod
    def columns(table_name):
        """List ``(column_name, type_name)`` pairs for ``table_name``.

        ``type_name`` is the lower-cased ``__name__`` of the column's Python
        type (e.g. "int", "str", "decimal", "datetime"); "str" is used when the
        type has no Python equivalent.

        :return: the list, or [] when there is no schema or no such table.
        """
        table = DataBase._table(table_name)
        if table is None:
            return []
        cols = []
        for c in table.columns:
            py_type = DataBase._python_type(c)
            cols.append((c.name, py_type.__name__.lower() if py_type is not None else "str"))
        return cols

    @staticmethod
    def column(table_name, column_name):
        """Return the reflected Column ``column_name`` of ``table_name``, or None."""
        table = DataBase._table(table_name)
        if table is None:
            return None
        for c in table.columns:
            if c.name == column_name:
                return c
        return None

    @staticmethod
    def pk(table_name):
        """Return the name of the first primary-key column of ``table_name``.

        If the table has no primary key, the first column's name is returned
        instead (presumably so templates always get some key column).
        Returns "" when there is no schema, no such table, or no columns.
        """
        table = DataBase._table(table_name)
        if table is None:
            return ""
        columns = list(table.columns)
        for c in columns:
            if c.primary_key:
                return c.name
        return columns[0].name if columns else ""

    @staticmethod
    def columns_meta(table_name):
        """Describe every column of ``table_name`` as a ShortColumn.

        Python types map to coarse labels: int / Decimal / float -> "number",
        datetime -> "datetime", anything else (including types without a
        Python equivalent) -> "text". ``length`` holds an integer's display
        width or a decimal's precision and ``ext_length`` a decimal's scale,
        when the column type exposes them; otherwise both are 0.

        :return: list of ShortColumn, or None when there is no schema or no
            such table.
        """
        table = DataBase._table(table_name)
        if table is None:
            return None
        cols_arr = []
        for c in table.columns:
            info = ShortColumn()
            info.name = c.name
            info.primary_key = c.primary_key
            info.nullable = c.nullable
            info.foreign_key = bool(c.foreign_keys)
            info.length = 0
            info.ext_length = 0
            py_type = DataBase._python_type(c)
            if py_type is int:
                info.data_type = "number"
                # display_width only exists on MySQL integer types; other
                # dialects' Integer lacks it, and MySQL may report None.
                info.length = getattr(c.type, "display_width", None) or 0
            elif py_type is decimal.Decimal:
                info.data_type = "number"
                info.length = getattr(c.type, "precision", None) or 0
                info.ext_length = getattr(c.type, "scale", None) or 0
            elif py_type is float:
                info.data_type = "number"
            elif py_type is datetime:
                info.data_type = "datetime"
            else:
                info.data_type = "text"
            cols_arr.append(info)
        return cols_arr

    @staticmethod
    def request(handler):
        """Dispatch an AJAX call from the code-generation page.

        The ``action`` request argument selects the operation:
        "tables", "columns", "dsn_get", "dsn_del", "dsn_save" or
        "bo_code_generation". Every operation replies via
        ``handler.render_json``; a missing or unknown action is answered with
        ``success=False``.
        """
        action = handler.get_argument("action", "")
        actions = {
            "tables": DataBase._request_tables,
            "columns": DataBase._request_columns,
            "dsn_get": DataBase._request_dsn_get,
            "dsn_del": DataBase._request_dsn_del,
            "dsn_save": DataBase._request_dsn_save,
            "bo_code_generation": DataBase._request_code_generation,
        }
        fn = actions.get(action)
        if fn is None:
            handler.render_json(success=False)
        else:
            fn(handler)

    @staticmethod
    def _request_tables(handler):
        """Connect to the DSN sent as ``cnn_name`` and reply with its table names."""
        # NOTE(review): the client-supplied value goes straight to
        # create_engine(); confirm this endpoint is restricted to admins.
        cnn_name = handler.get_argument("cnn_name", "")
        if not DataBase.connect(cnn_name):
            handler.render_json(msg="连接错误! 请检查参数是否设置正确.", success=False)
        else:
            handler.render_json(DataBase.tables())

    @staticmethod
    def _request_columns(handler):
        """Reply with the ``(name, type)`` pairs of the table in ``table_name``."""
        table_name = handler.get_argument("table_name", "")
        handler.render_json(DataBase.columns(table_name))

    @staticmethod
    def _dsn_store():
        """Return the saved {connection name: DSN} dict from the cache, never None."""
        return Cache.data().get("bo_code_dsn", dict()) or dict()

    @staticmethod
    def _request_dsn_get(handler):
        """Reply with all saved DSNs."""
        handler.render_json(DataBase._dsn_store())

    @staticmethod
    def _request_dsn_del(handler):
        """Remove the DSN named ``db_conn_name`` and reply with the remaining ones."""
        db_conn_name = handler.get_argument("db_conn_name", "")
        dsn = DataBase._dsn_store()
        # pop() rather than del: deleting an unknown name must not raise.
        dsn.pop(db_conn_name, None)
        Cache.data().set("bo_code_dsn", dsn)
        handler.render_json(dsn)

    @staticmethod
    def _request_dsn_save(handler):
        """Build a DSN from the ``db_*`` form fields and save it as ``db_conn_name``."""
        db_conn_name = handler.get_argument("db_conn_name", "")
        db_type = handler.get_argument("db_type", "")
        db_host = handler.get_argument("db_host", "")
        db_port = handler.get_argument("db_port", "")
        db_user = handler.get_argument("db_user", "")
        db_pwd = handler.get_argument("db_pwd", "")
        db_name = handler.get_argument("db_name", "")
        dsn = DataBase._dsn_store()
        dsn[db_conn_name] = DataBase.build_dsn(db_type, db_host, db_port, db_user, db_pwd, db_name)
        Cache.data().set("bo_code_dsn", dsn)
        handler.render_json(success=True)

    @staticmethod
    def _option_params(handler):
        """Collect every ``opt_<key>`` request argument into ``{key: value}``."""
        prefix = "opt_"
        # Slice off only the leading prefix; str.replace would also remove
        # "opt_" occurring later in the argument name.
        return {p[len(prefix):]: handler.get_argument(p, "")
                for p in handler.request.arguments
                if p.startswith(prefix)}

    @staticmethod
    def _request_code_generation(handler):
        """Render ``template_file`` for ``current_table_name`` and reply with the text."""
        table_name = handler.get_argument("current_table_name", "")
        template_file = handler.get_argument("template_file", "")
        if not template_file:
            handler.render_json(success=True, data="")
            return
        # NOTE(review): template_file comes straight from the request and is
        # rendered by handler.render_string (presumably a Tornado template,
        # which can run Python); restrict it to known templates if untrusted
        # users can reach this endpoint.
        rendered = handler.render_string(
            template_file,
            TableName=Utils.capitalize(table_name), table_name=table_name,
            exclude_columns=handler.get_arguments("exclude_columns"),
            search_columns=handler.get_arguments("search_columns"),
            order_columns=handler.get_arguments("order_columns"),
            freeze_columns=handler.get_arguments("freeze_columns"),
            columns=DataBase.columns_meta(table_name),
            primary_key=DataBase.pk(table_name),
            option=DataBase._option_params(handler),
        )
        txt = rendered.decode('utf-8')
        # Round-trip through a temp file so Utils.text_file_compact can
        # post-process the output.
        # NOTE(review): the temp file is not removed here; confirm whether
        # Utils.mk_temp_file cleans up after itself.
        file_name = Utils.mk_temp_file()
        Utils.text_write(file_name, [txt])
        Utils.text_file_compact(file_name)
        handler.render_json(success=True, data=Utils.text_read(file_name))
